package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"regexp"
	"strings"

	"github.com/Ice3man543/nvd"
	"github.com/projectdiscovery/nuclei/v2/pkg/catalog"
)

var (
	input       = flag.String("i", "", "Templates to annotate")
	templateDir = flag.String("d", "", "Custom template directory for update")
)

// main parses the command-line flags, requires both -i and -d to be set, and
// then runs process, exiting through log.Fatalf on any failure.
func main() {
	flag.Parse()

	switch {
	case *input == "", *templateDir == "":
		log.Fatalf("invalid input, see -h\n")
	}

	err := process()
	if err != nil {
		log.Fatalf("could not process: %s\n", err)
	}
}

// process resolves the templates selected by -i within the -d template
// directory and runs getCVEData on each one.
//
// A scratch directory is created for the NVD client and removed on return.
// It returns an error if the scratch directory or NVD client cannot be
// created, if the template paths cannot be resolved, or if a template file
// cannot be read. getCVEData returns no error, so problems while annotating
// an individual template do not abort the run.
func process() error {
	// TempDir replaces the last "*" in the pattern with a random string; it
	// does not interpret printf verbs.
	tempDir, err := ioutil.TempDir("", "nuclei-nvd-*")
	if err != nil {
		return fmt.Errorf("creating temporary directory: %w", err)
	}
	defer os.RemoveAll(tempDir)

	client, err := nvd.NewClient(tempDir)
	if err != nil {
		return fmt.Errorf("creating nvd client: %w", err)
	}
	// Named so that it does not shadow the imported catalog package.
	templateCatalog := catalog.New(*templateDir)

	paths, err := templateCatalog.GetTemplatePath(*input)
	if err != nil {
		return fmt.Errorf("resolving templates %q: %w", *input, err)
	}
	for _, path := range paths {
		data, err := ioutil.ReadFile(path)
		if err != nil {
			return fmt.Errorf("reading template %q: %w", path, err)
		}
		getCVEData(client, path, string(data))
	}
	return nil
}

var (
	// idRegex captures a CVE identifier such as "CVE-2021-44228" from a
	// template's "id:" field. Each letter of the "CVE" prefix may be either
	// case. The previous form [C|c] also matched a literal '|' because '|'
	// has no alternation meaning inside a character class.
	idRegex = regexp.MustCompile(`id: ([Cc][Vv][Ee]-[0-9]+-[0-9]+)`)
	// severityRegex captures the lowercase word following "severity: ".
	severityRegex = regexp.MustCompile(`severity: ([a-z]+)`)
)

// getCVEData annotates a single template with metadata fetched from NVD.
//
// filePath is the template's location on disk and data is its full contents.
// The template is left untouched in any of these cases:
//   - its id is not a CVE identifier;
//   - it has no severity;
//   - it already contains "classification:";
//   - it lacks an "info:" block followed by a "requests:" or "network:" section;
//   - the CVE cannot be fetched.
//
// Otherwise the info block is updated in four ways:
//   - the severity is replaced when it disagrees with the CVSS v3 base score;
//   - a classification block (CVSS metrics/score, CVE id and, when useful,
//     CWE ids) is appended;
//   - description and reference fields are appended when missing.
//
// The file is rewritten only when something changed. Fetch and write errors
// are logged rather than returned, so one bad template does not stop a batch.
func getCVEData(client *nvd.Client, filePath, data string) {
	idMatch := idRegex.FindStringSubmatch(data)
	if idMatch == nil {
		return
	}
	cveName := idMatch[1]

	severityMatch := severityRegex.FindStringSubmatch(data)
	if severityMatch == nil {
		// Previously this checked the wrong slice and then panicked on the
		// empty severity match below.
		return
	}
	severityValue := severityMatch[1]

	// Skip if there's classification data already
	if strings.Contains(data, "classification:") {
		return
	}

	// Locate the info block before hitting the network: the metadata is
	// spliced in right before the first protocol section after "info:".
	infoIndex := strings.Index(data, "info:")
	if infoIndex == -1 {
		return
	}
	afterInfo := data[infoIndex:]
	sectionIndex := firstSectionIndex(afterInfo, "requests:", "network:")
	if sectionIndex == -1 {
		return
	}
	infoBlockClean := strings.TrimRight(afterInfo[:sectionIndex], "\n")

	cveItem, err := client.FetchCVE(cveName)
	if err != nil {
		log.Printf("Could not fetch cve %s: %s\n", cveName, err)
		return
	}
	var cweID []string
	for _, problemData := range cveItem.CVE.Problemtype.ProblemtypeData {
		for _, description := range problemData.Description {
			cweID = append(cweID, description.Value)
		}
	}
	cvssScore := cveItem.Impact.BaseMetricV3.CvssV3.BaseScore
	cvssMetrics := cveItem.Impact.BaseMetricV3.CvssV3.VectorString

	newInfoBlock := infoBlockClean
	changed := false

	if newSeverity := isSeverityMatchingCvssScore(severityValue, cvssScore); newSeverity != "" {
		changed = true
		newInfoBlock = strings.ReplaceAll(newInfoBlock, severityMatch[0], "severity: "+newSeverity)
		fmt.Printf("Adjusting severity for %s from %s=>%s (%.2f)\n", filePath, severityValue, newSeverity, cvssScore)
	}

	// Everything below is appended to the end of the info block.
	var additions strings.Builder
	if !strings.Contains(infoBlockClean, "classification") && cvssScore != 0 && cvssMetrics != "" {
		fmt.Fprintf(&additions, "\n  classification:\n    cvss-metrics: %s\n    cvss-score: %.2f\n    cve-id: %s", cvssMetrics, cvssScore, cveName)
		// NVD placeholder CWE values carry no information; leave cwe-id out then.
		if len(cweID) > 0 && cweID[0] != "NVD-CWE-Other" && cweID[0] != "NVD-CWE-noinfo" {
			fmt.Fprintf(&additions, "\n    cwe-id: %s", strings.Join(cweID, ","))
		}
	}
	// If there is no description field, fill the description from CVE information
	if !strings.Contains(infoBlockClean, "description:") && len(cveItem.CVE.Description.DescriptionData) > 0 {
		fmt.Fprintf(&additions, "\n  description: %q", cveItem.CVE.Description.DescriptionData[0].Value)
	}
	if !strings.Contains(infoBlockClean, "reference:") && len(cveItem.CVE.References.ReferenceData) > 0 {
		additions.WriteString("\n  reference:")
		for _, reference := range cveItem.CVE.References.ReferenceData {
			fmt.Fprintf(&additions, "\n    - %s", reference.URL)
		}
	}
	if additions.Len() > 0 {
		changed = true
		newInfoBlock += additions.String()
	}
	if !changed {
		return
	}

	// Splice at the known offset instead of searching and replacing the
	// block text elsewhere in the file.
	newTemplate := data[:infoIndex] + newInfoBlock + data[infoIndex+len(infoBlockClean):]
	// WriteFile keeps the permissions of an existing file; the mode only
	// applies if the file had to be created.
	if err := ioutil.WriteFile(filePath, []byte(newTemplate), 0644); err != nil {
		log.Printf("Could not write template %s: %s\n", filePath, err)
		return
	}
	fmt.Printf("Wrote updated template to %s\n", filePath)
}

// firstSectionIndex returns the smallest index in s at which any of markers
// occurs, or -1 when none of them is present.
func firstSectionIndex(s string, markers ...string) int {
	first := -1
	for _, marker := range markers {
		if i := strings.Index(s, marker); i != -1 && (first == -1 || i < first) {
			first = i
		}
	}
	return first
}

// isSeverityMatchingCvssScore compares a template severity with the severity
// band implied by a CVSS base score.
//
// The bands are 0.1-3.9 low, 4.0-6.9 medium, 7.0-8.9 high and 9.0-10.0
// critical, all inclusive. It returns the band's name when the score falls in
// a band whose name differs from severity. It returns "" when the names agree,
// when score is zero, or when score lies outside every band.
func isSeverityMatchingCvssScore(severity string, score float64) string {
	if score == 0.0 {
		return ""
	}
	bands := [...]struct {
		lo, hi float64
		name   string
	}{
		{0.1, 3.9, "low"},
		{4.0, 6.9, "medium"},
		{7.0, 8.9, "high"},
		{9.0, 10.0, "critical"},
	}
	for _, band := range bands {
		if score < band.lo || score > band.hi {
			continue
		}
		if band.name == severity {
			return ""
		}
		return band.name
	}
	return ""
}
